Deep Neural Networks as Gaussian Processes

Let's consider a regression problem. We want to find a function $f$ that maps vectors $x\in\mathbb{R}^{d_{in}}$ to values $y\in\mathbb{R}$. For this problem, the most common probabilistic model is:

$$ \begin{aligned} y_i &= f(x_i) + \epsilon_i\\ \epsilon_i &\sim \mathcal{N}(0,\sigma_\epsilon^2)\\ \epsilon_i,\epsilon_j & \quad \text{independent } \forall i,j \end{aligned} $$
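
Just to fix ideas, here is a minimal NumPy sketch of this generative model; the ground-truth $f$, the sizes and the noise level are made-up choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_in, sigma_eps = 50, 3, 0.1

X = rng.normal(size=(N, d_in))                    # inputs x_i in R^{d_in}
f_true = lambda X: np.sin(X @ np.ones(d_in))      # hypothetical ground-truth f
y = f_true(X) + sigma_eps * rng.normal(size=N)    # y_i = f(x_i) + eps_i, eps_i iid N(0, sigma_eps^2)
```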

This model assumes that a pair of outputs $y_i$, $y_j$ are independent given $x_i$, $x_j$ and the function $f$. Since, given $f$, the data is normally distributed, we can write this explicitly as follows:

$$ p\left(\begin{pmatrix} y_i \\y_j \end{pmatrix} \mid \begin{pmatrix} x_i \\x_j \end{pmatrix}, f \right) \sim \mathcal{N}\left( \begin{pmatrix} f(x_i) \\f(x_j) \end{pmatrix}, \sigma_\epsilon^2 I \right) \quad (1) $$

Given that probabilistic model and some observed values $\mathcal{D} = \{x_i,y_i\}_{i=1}^N$, we want to make predictions. Making predictions, in Bayesian language, means finding a predictive posterior distribution: $ p(y^\star\mid x^\star, \mathcal{D})$. Notice that with this predictive distribution the problem is solved: we just have to plug in a value $x^\star$ and we obtain the distribution that the prediction follows. From this distribution we can compute, e.g., the mean and use it as a point estimate of the outcome.

To arrive at this predictive posterior there are two paths; I call these paths the standard Bayesian approach and the $\mathcal{GP}$ approach. We will show these two paths for the general case and for the particular cases of the linear regression model, the Neural Network regression model and the $\mathcal{GP}$ regression model. We will see that the $\mathcal{GP}$ approach applied to the Neural Network regression model leads to the NNGP formulation of the paper.

The standard Bayesian approach

The standard Bayesian approach consists of first finding the distribution of $f$ given the data ($p(f\mid\mathcal{D})$, called the posterior distribution) and then integrating over all possible choices of $f$:

$$ p(y^\star \mid x^\star,\mathcal{D}) = \int p(y^\star \mid x^\star, f) p(f \mid \mathcal{D}) df \quad (2) $$

To compute the posterior distribution $p(f\mid \mathcal{D})$ we rely on Bayes' formula:

$$ p(f \mid \mathcal{D}) = \frac{p(y \mid X, f)p(f\mid X)}{p(y \mid X)} = \frac{p(y \mid X, f)p(f\mid X)}{\int p(y \mid X, f)p(f\mid X) df} \quad (3) $$

The first term in the numerator is called the likelihood. We can compute this term exactly, since it corresponds to equation $(1)$. The second term in the numerator is called the prior. This is something we have to choose to make the whole formula work. The integral in the denominator is called the marginal likelihood. This is the difficult part of the formula: it is what often makes computing the posterior intractable.

Bayesian linear regression (standard approach)

As an example, consider linear regression. In linear regression we restrict ourselves to linear functions $f$: $f(x) = w^\top x + w_0$. To set a prior over $f$ we just have to set a prior over the weights and the bias term. A common approach is to set a zero-mean normal prior with isotropic (diagonal) covariance, $w\sim \mathcal{N}(0,\sigma_w^2 I)$. With this prior the marginal likelihood (the integral in the denominator of equation $(3)$) is tractable:

$$ p(y \mid X) = \int p(y \mid X, w) p(w) dw = \int \mathcal{N}(y \mid X'w, \sigma_\epsilon^2 I)\, \mathcal{N}(w\mid 0, \sigma_w^2 I)\, dw = \mathcal{N}\left(y \mid 0,\ \sigma_\epsilon^2 I + \sigma_w^2X'X'^\top\right) \quad (4) $$

($X'$ is $X$ with a column of ones appended, so that the bias $w_0$ is absorbed into $w$)

Here we use the standard tricks of the trade for normal distributions (see, e.g., Bishop, section 2.3.3).

Since, for the linear regression case, all terms in equation $(3)$ are Gaussian, it turns out that the posterior $p(w\mid \mathcal{D})$ is also Gaussian. In addition, since the posterior is Gaussian and the likelihood is Gaussian, the predictive posterior of equation $(2)$ (which is our final goal!) is also Gaussian. We write down these equations for completeness. (For details of the derivation, again Bishop 2.3.3 is a good place to go.)

$$ \begin{aligned} p(w \mid \mathcal{D}) &\sim \mathcal{N}\left( \left(\tfrac{\sigma_\epsilon^2}{\sigma_w^2}I + X'^\top X'\right)^{-1}X'^\top y, \left(\tfrac{1}{\sigma_w^2}I + \tfrac{1}{\sigma_\epsilon^2} X'^\top X'\right)^{-1} \right) \\ p(y^\star \mid x^\star, \mathcal{D}) &\sim \mathcal{N}\Big(x'^{\star\top}\left(\tfrac{\sigma_\epsilon^2}{\sigma_w^2}I + X'^\top X'\right)^{-1}X'^\top y , \\ &\quad\quad \sigma_\epsilon^2 + \sigma_\epsilon^2x'^{\star\top} \left(\tfrac{\sigma_\epsilon^2}{\sigma_w^2}I + X'^\top X'\right)^{-1}x'^{\star} \Big) \quad (5) \end{aligned} $$

(We write $x'^\star$ for the vector $x^\star$ with an extra 1 appended)
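
The following is a minimal NumPy sketch of equation $(5)$: it computes the posterior mean of the weights and the predictive mean and variance for a new input. The function name `blr_predict` and the use of a linear solve instead of an explicit inverse are just illustrative choices.

```python
import numpy as np

def blr_predict(X, y, x_star, sigma_w2, sigma_eps2):
    """Predictive posterior of Bayesian linear regression, equation (5)."""
    Xp = np.hstack([X, np.ones((X.shape[0], 1))])   # X': X with a ones column
    xp = np.append(x_star, 1.0)                     # x'*: x* with an extra 1
    A = (sigma_eps2 / sigma_w2) * np.eye(Xp.shape[1]) + Xp.T @ Xp
    w_mean = np.linalg.solve(A, Xp.T @ y)           # posterior mean of the weights
    mean = xp @ w_mean                              # predictive mean
    var = sigma_eps2 + sigma_eps2 * xp @ np.linalg.solve(A, xp)  # predictive variance
    return mean, var
```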

Neural networks (standard approach)

Unfortunately, this "easy" derivation can't be done for non-linear choices of $f$. This is because we can't apply the Bishop 2.3.3 trick to equation $(4)$ when the output is not a linear function of the weights $w$. Consider the case of Neural Networks. Feed-forward fully connected neural networks are models like the following (for a 2-hidden-layer network):

$$ \begin{aligned} x^1 &= \phi(W^0x + b^0)\\ z^1 &= W^1x^1 + b^1 = W^1\phi(W^0x + b^0)+ b^1 \\ x^2 &= \phi(z^1) = \phi(W^1x^1 + b^1) = \phi(W^1\phi(W^0x + b^0)+ b^1)\\ z^2 &= W^2x^2+b^2 = W^2\phi(W^1x^1 + b^1) +b^2 = W^2 \phi(W^1\phi(W^0x + b^0)+ b^1) +b^2 \end{aligned} $$

We call $\omega$ the set of all weights and biases of the network. In this particular case $\omega = \{W^0,b^0,W^1,b^1,W^2,b^2\}$, $\phi$ is the non-linear activation function, and the final mapping is $f_\omega(x) = z^2$. We can also set a zero-mean isotropic normal prior over the weights, $\omega\sim \mathcal{N}(0,\sigma_\omega^2 I)$.
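
As a concrete reference, here is a minimal NumPy sketch of this 2-hidden-layer network with weights drawn from the prior. The widths, the tanh activation and the helper names (`sample_network`, `f_omega`) are assumptions for illustration; inputs are stacked as rows, so the weight matrices are stored transposed with respect to the formulas above.

```python
import numpy as np

def sample_network(d_in, widths=(100, 100), sigma_omega=1.0, rng=np.random.default_rng(0)):
    """Draw omega = {W^0, b^0, W^1, b^1, W^2, b^2} from the prior N(0, sigma_omega^2 I)."""
    dims = [d_in, *widths, 1]
    return [(sigma_omega * rng.normal(size=(dims[l], dims[l + 1])),   # W^l (stored transposed)
             sigma_omega * rng.normal(size=dims[l + 1]))              # b^l
            for l in range(len(dims) - 1)]

def f_omega(X, omega, phi=np.tanh):
    """f_omega(x) = z^2 = W^2 phi(W^1 phi(W^0 x + b^0) + b^1) + b^2, applied row-wise."""
    h = X
    for W, b in omega[:-1]:
        h = phi(h @ W + b)        # x^{l+1} = phi(W^l x^l + b^l)
    W, b = omega[-1]
    return h @ W + b              # final affine layer, no activation
```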

For this model, the marginal likelihood is: $$ p(y \mid X) = \int \mathcal{N}\left(y \mid f_\omega(X), \sigma_\epsilon^2 I \right) \mathcal{N}(\omega\mid 0, \sigma_\omega^2I) d\omega $$ As we said, this integral does not have a closed-form solution. Therefore, to compute the posterior over the weights of this neural network, we have to rely on, e.g., variational inference, expectation propagation or Monte Carlo methods. These approaches give us a distribution over the weights $\omega$ that is close to the true posterior $p(\omega\mid \mathcal{D})$. With this approximation we go back to equation $(2)$ to obtain the predictive posterior.
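
Even without a closed form we can approximate this integral by brute force. A rough Monte Carlo sketch (reusing `sample_network` and `f_omega` from the sketch above, and assuming SciPy is available) averages the Gaussian likelihood over weight samples drawn from the prior, in log space for numerical stability.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def mc_log_marginal_likelihood(X, y, sigma_eps2, n_samples=2000, rng=np.random.default_rng(1)):
    """Monte Carlo estimate of log p(y | X) = log E_omega[ N(y | f_omega(X), sigma_eps^2 I) ]."""
    log_liks = []
    for _ in range(n_samples):
        omega = sample_network(X.shape[1], rng=rng)   # omega ~ prior
        mean = f_omega(X, omega).ravel()              # f_omega(X)
        log_liks.append(multivariate_normal.logpdf(y, mean, sigma_eps2 * np.eye(len(y))))
    return logsumexp(log_liks) - np.log(n_samples)
```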

The $\mathcal{GP}$ approach

There is another path to reach the predictive posterior that does not involve finding the posterior of $f$ given the data ($p(f\mid \mathcal{D})$). This approach consists of (a) building the joint likelihood, (b) integrating out $f$, and (c) obtaining the predictive posterior from the resulting joint marginal distribution. This is the approach that $\mathcal{GP}$s follow.

First we notice that, for some observed data $\mathcal{D}=(X,y)$ and a test point $(x^\star,y^\star)$, if we are given $f$, our modeling assumption $(1)$ implies that:

$$ p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix}, f \right) \sim \mathcal{N}\left( \begin{pmatrix} f(X) \\f(x^{\star}) \end{pmatrix}, \sigma_\epsilon^2 I \right) $$

We call this equation the train-test joint likelihood. With this joint likelihood we can try to integrate out $f$ to get the joint marginal distribution:

$$ p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix} \right) = \int p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix}, f \right) p(f\mid X, x^\star) df =\int \mathcal{N}\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} f(X) \\f(x^\star) \end{pmatrix}, \sigma_\epsilon^2 I \right) p(f\mid X, x^\star) df \quad (6) $$

If we manage to obtain the joint marginal, it sometimes happens that the predictive posterior is easy to find. This is the case if, for example, the joint marginal is Gaussian: then we can obtain the predictive distribution in closed form using the standard Gaussian conditioning formula (see Wikipedia).
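
For reference, the conditioning result for a joint Gaussian that we will use repeatedly below is:

$$ \begin{pmatrix} a \\ b \end{pmatrix} \sim \mathcal{N}\left( \begin{pmatrix} \mu_a \\ \mu_b \end{pmatrix}, \begin{pmatrix} \Sigma_{aa} & \Sigma_{ab} \\ \Sigma_{ba} & \Sigma_{bb} \end{pmatrix}\right) \implies p(b \mid a) \sim \mathcal{N}\left(\mu_b + \Sigma_{ba}\Sigma_{aa}^{-1}(a-\mu_a),\ \Sigma_{bb} - \Sigma_{ba}\Sigma_{aa}^{-1}\Sigma_{ab}\right) $$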

There are two models where we can easily derive the predictive posterior distribution using the $\mathcal{GP}$ approach: the linear model and the $\mathcal{GP}$.

Bayesian linear regression ($\mathcal{GP}$ approach)

The derivation of the joint marginal for the linear model is exactly analogous to the derivation in equation $(4)$. This leads to:

$$ p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix} \right) \sim \mathcal{N}\left( 0, \sigma_\epsilon^2 I + \sigma_w^2 \begin{pmatrix} X'X'^\top & X'x'^\star \\ x'^{\star\top} X'^\top & x'^{\star\top} x'^\star \end{pmatrix}\right) $$

With this joint marginal we can apply the conditioning formula above to obtain the predictive posterior:

$$ \begin{aligned} p\left(y^\star \mid \overbrace{y,X}^{\mathcal{D}},x^\star\right) &\sim \mathcal{N}\Big(\sigma_w^2 x'^{\star\top}X'^\top \left(\sigma_w^2 X'X'^\top + \sigma_\epsilon^2I\right)^{-1}y, \\ &\quad\quad \sigma_\epsilon^2 + \sigma_w^2 x'^{\star\top}x'^{\star} - \sigma_w^2 x'^{\star\top}X'^\top\left( X'X'^\top + \tfrac{\sigma_\epsilon^2}{\sigma_w^2} I\right)^{-1}X'x'^\star \Big) \quad (7) \end{aligned} $$

Believe it or not, this equation is equivalent to equation $(5)$! It gives the same mean and variance! To show this we just have to apply the matrix inversion lemma carefully.
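
A quick numerical sanity check (a sketch with made-up data, reusing `blr_predict` from the earlier sketch for equation $(5)$) confirms that the two predictive distributions coincide without going through the algebra:

```python
import numpy as np

rng = np.random.default_rng(2)
X, y = rng.normal(size=(20, 3)), rng.normal(size=20)
x_star, sigma_w2, sigma_eps2 = rng.normal(size=3), 0.5, 0.1

Xp = np.hstack([X, np.ones((20, 1))])            # X'
xp = np.append(x_star, 1.0)                      # x'*

# Equation (7): condition the joint marginal of the linear model.
K = sigma_w2 * Xp @ Xp.T + sigma_eps2 * np.eye(20)
k_star = sigma_w2 * Xp @ xp
mean7 = k_star @ np.linalg.solve(K, y)
var7 = sigma_eps2 + sigma_w2 * xp @ xp - k_star @ np.linalg.solve(K, k_star)

# Equation (5): standard Bayesian linear regression predictive.
mean5, var5 = blr_predict(X, y, x_star, sigma_w2, sigma_eps2)
assert np.allclose([mean5, var5], [mean7, var7])
```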

$\mathcal{GP}$ Regression

We can of course retrieve the $\mathcal{GP}$ regression solution using the $\mathcal{GP}$ approach. For $\mathcal{GP}$s, we set a strange prior over $f$ ($p(f\mid X,x^\star)$). This prior says that any finite collection of values $f(x_1),\dots,f(x_k)$ is multivariate-normally distributed. The most common $\mathcal{GP}$ regression approach sets the mean of this normal distribution to 0 and gives the covariance through a kernel function $k(\cdot,\cdot)$:

$$ p\left(\begin{pmatrix} f(x_1)\\\vdots\\f(x_k)\end{pmatrix}\right) \sim \mathcal{N}\left(0, \begin{pmatrix} k(x_1,x_1) & \dots & k(x_1,x_k) \\ \vdots & \ddots &\vdots \\ k(x_k,x_1) & \dots & k(x_k,x_k) \end{pmatrix}\right) $$
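
To make this prior less abstract, here is a minimal NumPy sketch that draws one joint sample of $(f(x_1),\dots,f(x_k))$; the squared-exponential kernel and the grid of inputs are just example choices.

```python
import numpy as np

def sq_exp_kernel(A, B, lengthscale=1.0):
    """Example kernel k(x, x') = exp(-||x - x'||^2 / (2 lengthscale^2))."""
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)   # pairwise squared distances
    return np.exp(-0.5 * d2 / lengthscale ** 2)

rng = np.random.default_rng(3)
Xk = np.linspace(-3, 3, 50)[:, None]                  # k = 50 input points
K = sq_exp_kernel(Xk, Xk) + 1e-8 * np.eye(50)         # small jitter for numerical stability
f_sample = rng.multivariate_normal(np.zeros(50), K)   # one joint draw of (f(x_1), ..., f(x_k))
```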

With this peculiar prior, the integral that gives the joint marginal of eq. (6) is:

$$ \begin{aligned} p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix} \right) &= \int \mathcal{N}\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} f(X) \\f(x^\star) \end{pmatrix}, \sigma_\epsilon^2 I \right) p(f\mid X,x^\star) df \\ &\sim \mathcal{N}\left(0, \begin{pmatrix}K_{X,X} + \sigma_\epsilon^2 I & K_{X,x^\star} \\ K_{x^\star,X} & k(x^\star,x^\star)\end{pmatrix}\right) \end{aligned} $$

With this joint marginal we can apply the conditioning formula again to obtain the predictive posterior of the $\mathcal{GP}$ regression:

$$ \begin{aligned} p(y^\star \mid y,X,x^\star) &\sim \mathcal{N}\Big(K_{x^\star,X}\left(K_{X,X}+\sigma_\epsilon^2I\right)^{-1}y,\\ &\quad\quad \sigma_\epsilon^2 + k(x^\star,x^\star) - K_{x^\star,X}\left(K_{X,X}+\sigma_\epsilon^2I\right)^{-1} K_{X,x^\star} \Big) \end{aligned} $$

It is really interesting to see how this equation resembles equation $(7)$; in fact, we would arrive at equation $(7)$ from the $\mathcal{GP}$ regression formulation by using the linear kernel $k(x,\tilde{x})=\sigma_w^2\, x^\top \tilde{x}$ (applied to the inputs augmented with a ones entry, to account for the bias)!!
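
A minimal NumPy sketch of the $\mathcal{GP}$ predictive posterior above; the kernel is passed in as a function, and with the linear kernel applied to inputs augmented with a ones column it reproduces equation $(7)$.

```python
import numpy as np

def gp_predict(X, y, x_star, kernel, sigma_eps2):
    """GP regression predictive mean and variance for a single test input x_star."""
    K_XX = kernel(X, X)
    K_Xs = kernel(X, x_star[None, :]).ravel()               # K_{X, x*}
    k_ss = kernel(x_star[None, :], x_star[None, :]).item()  # k(x*, x*)
    alpha = np.linalg.solve(K_XX + sigma_eps2 * np.eye(len(y)), K_Xs)
    mean = alpha @ y                                         # K_{x*,X} (K_{X,X} + sigma_eps^2 I)^{-1} y
    var = sigma_eps2 + k_ss - K_Xs @ alpha                   # predictive variance
    return mean, var

# e.g. the linear kernel of the text (on inputs already augmented with a ones column):
linear_kernel = lambda A, B, sigma_w2=0.5: sigma_w2 * A @ B.T
```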

Neural networks ($\mathcal{GP}$ approach)

As we said before, the NNGP of the paper is a Neural Network regression model that follows the $\mathcal{GP}$ approach to compute the predictive posterior distribution. Let's consider the Neural Network model we had before ($f_\omega(x)$). In order to find the joint marginal we have to integrate the weights out of the joint likelihood, as in equation $(6)$. That is:

$$ p\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix} \right) =\int \mathcal{N}\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} f_\omega(X) \\f_\omega(x^\star) \end{pmatrix}, \sigma_\epsilon^2 I \right) \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) d\omega \quad (8) $$

However, as we said before, we can't apply Bishop's 2.3.3 trick here, since the output does not depend linearly on $\omega$! Therefore we are trapped in the same problem as before and we can't continue!

So, what is the trick that allows them to continue? Notice that the exact same trick could also be used to compute the marginal likelihood of the standard Bayesian approach; with that marginal likelihood we could obtain the posterior over the weights $\omega$ and then reach the predictive posterior through the standard Bayesian path as well.

So what is their trick? They prove that, in the limit of infinite width of each layer of the Neural Network, the joint marginal distribution (and the marginal likelihood too) is a zero-mean multivariate normal (i.e. a $\mathcal{GP}$) whose covariance can be computed using a recursive formula.

It is possible to obtain the mean and covariance function of the limiting $\mathcal{GP}$ using equation $(8)$. Let's show that the distribution of equation $(8)$ has mean zero; we use here the same prior over the weights as in the paper ($W^l\sim\mathcal{N}\left(0,\tfrac{\sigma_w^2}{N_l}I\right)$ and $b^l\sim\mathcal{N}\left(0,\sigma_b^2 I\right)$):

$$ \begin{aligned} \mathbb{E}\left[\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix}\right] &= \int\int \begin{pmatrix} y \\y^\star \end{pmatrix} \mathcal{N}\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} f_\omega(X) \\f_\omega(x^\star) \end{pmatrix}, \sigma_\epsilon^2 I \right) \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) dydy^\star d\omega \\ &= \int \begin{pmatrix} f_\omega(X) \\f_\omega(x^\star) \end{pmatrix} \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) d\omega \\ &= \int \begin{pmatrix} x^L(X)W^L+b^L \\ x^L(x^{\star\top})W^L+b^L \end{pmatrix} \mathcal{N}\left(W^L \mid 0, \tfrac{\sigma_w^2}{N_L} I\right)\mathcal{N}(b^L \mid 0, \sigma_b^2 I ) \mathcal{N}(\omega^{1,\dots,L-1} \mid 0, \sigma_\omega^2 I ) dW^L db^L d\omega^{1,\dots,L-1} \\ &= 0 \end{aligned} $$

Since the mean of the joint distribution is zero, the covariance is just the second-order moment. Applying the same reasoning as before, we can try to compute that covariance:

$$ \begin{aligned} \mathbb{E}\left[\begin{pmatrix} y \\y^\star \end{pmatrix} \begin{pmatrix} y^\top & y^\star \end{pmatrix} \mid \begin{pmatrix} X \\ x^{\star\top} \end{pmatrix}\right] &= \int\int \begin{pmatrix} y \\y^\star \end{pmatrix} \begin{pmatrix} y^\top & y^\star \end{pmatrix} \mathcal{N}\left(\begin{pmatrix} y \\y^\star \end{pmatrix} \mid \begin{pmatrix} f_\omega(X) \\f_\omega(x^\star) \end{pmatrix}, \sigma_\epsilon^2 I \right) \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) dydy^\star d\omega \\ &= \sigma_\epsilon^2 I + \int \begin{pmatrix} f_\omega(X) \\f_\omega(x^\star) \end{pmatrix} \begin{pmatrix} f_\omega(X)^\top & f_\omega(x^\star) \end{pmatrix} \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) d\omega \\ &= \sigma_\epsilon^2 I + \int \begin{pmatrix} \left(x^L(X)W^L+b^L\right)\left(W^{L\top}x^{L}(X)^{\top}+b^{L\top}\right) & \left(x^L(X)W^L+b^L\right)\left(W^{L\top}x^L(x^{\star\top})^\top+b^{L\top}\right) \\ \left(x^L(x^{\star\top})W^L+b^L\right)\left(W^{L\top}x^{L}(X)^{\top}+b^{L\top}\right) & \left(x^L(x^{\star\top})W^L+b^L\right)\left(W^{L\top}x^L(x^{\star\top})^\top+b^{L\top}\right) \end{pmatrix} \mathcal{N}(\omega \mid 0, \sigma_\omega^2 I ) d\omega \\ &= \sigma_\epsilon^2 I + \int \begin{pmatrix} \tfrac{\sigma_w^2}{N_L}x^L(X)x^{L}(X)^{\top}+\sigma_b^2 & \tfrac{\sigma_w^2}{N_L}x^L(X)x^L(x^{\star\top})^\top+\sigma_b^2 \\ \tfrac{\sigma_w^2}{N_L} x^L(x^{\star\top})x^{L}(X)^{\top}+\sigma_b^2 & \tfrac{\sigma_w^2}{N_L} x^L(x^{\star\top})x^L(x^{\star\top})^\top+\sigma_b^2 \end{pmatrix} \mathcal{N}(\omega^{1,\dots,L-1} \mid 0, \sigma_\omega^2 I ) d\omega^{1,\dots,L-1} \\ \end{aligned} $$

However, we cannot go further analytically: $x^L(x) = \phi(z^{L-1}(x)) = \phi(W^{L-1}x^{L-1}(x) + b^{L-1})$ still depends on the remaining weights $\omega^{1,\dots,L-1}$ through the non-linearity $\phi$, so the last expectation has no closed form. This is exactly where the infinite-width limit comes in: in that limit the expectation can be computed, layer by layer, with the recursive formula mentioned above.
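
As a rough Monte Carlo illustration of the two expectations above, we can sample many networks from the prior (reusing `sample_network` and `f_omega` from the earlier sketch; for simplicity this keeps that isotropic prior rather than the paper's $\sigma_w^2/N_l$ scaling, which does not affect the zero-mean argument) and look at the empirical mean and covariance of the outputs for a pair of inputs:

```python
import numpy as np

rng = np.random.default_rng(4)
x_pair = rng.normal(size=(2, 3))                          # two inputs: x and x*
samples = np.stack([f_omega(x_pair, sample_network(3, rng=rng)).ravel()
                    for _ in range(5000)])                # draws of (f_omega(x), f_omega(x*))

print(samples.mean(axis=0))   # empirical prior mean of the outputs: close to (0, 0)
print(np.cov(samples.T))      # empirical prior covariance of the outputs
```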

